// Copyright Epic Games, Inc. All Rights Reserved.

#include "bench.h"

#include <zenbase/zenbase.h>
#include <zencore/except.h>

#if ZEN_PLATFORM_WINDOWS
#	include <stdio.h>
#	include <tchar.h>
#	include <windows.h>
#	include <exception>
#	include <fmt/format.h>

namespace zen::bench::util {

// Minimal local declarations for the native (ntdll) system-information API used below.
// See https://www.geoffchappell.com/studies/windows/km/ntoskrnl/api/ex/sysinfo/set.htm

// NTSTATUS must be a *signed* 32-bit value (ntdef.h: `typedef LONG NTSTATUS;`, and LONG is
// `long` on Windows). NT_SUCCESS tests `>= 0`, so with an unsigned type every status --
// including error codes such as 0xC0000061 -- would be reported as success.
typedef long NTSTATUS;

// Guarded in case an SDK header (ntdef.h / winternl.h / ntstatus.h) already provides them.
#	ifndef NT_SUCCESS
#		define NT_SUCCESS(Status) (((NTSTATUS)(Status)) >= 0)
#	endif
#	ifndef STATUS_PRIVILEGE_NOT_HELD
#		define STATUS_PRIVILEGE_NOT_HELD ((NTSTATUS)0xC0000061L)
#	endif

// Only the information class used by this file is declared; its numeric value must match
// the kernel's numbering.
typedef enum _SYSTEM_INFORMATION_CLASS
{
	SystemMemoryListInformation =
		80,	 // 80, q: SYSTEM_MEMORY_LIST_INFORMATION; s: SYSTEM_MEMORY_LIST_COMMAND (requires SeProfileSingleProcessPrivilege)
} SYSTEM_INFORMATION_CLASS;

// Undocumented command codes passed to NtSetSystemInformation(SystemMemoryListInformation, ...).
// Values are spelled out explicitly because they are part of the kernel ABI.
typedef enum _SYSTEM_MEMORY_LIST_COMMAND
{
	MemoryCaptureAccessedBits		  = 0,
	MemoryCaptureAndResetAccessedBits = 1,
	MemoryEmptyWorkingSets			  = 2,
	MemoryFlushModifiedList			  = 3,
	MemoryPurgeStandbyList			  = 4,
	MemoryPurgeLowPriorityStandbyList = 5,
	MemoryCommandMax				  = 6
} SYSTEM_MEMORY_LIST_COMMAND;

/**
 * Enables or disables a single named privilege on an access token.
 *
 * @param TokenHandle  Token opened with (at least) TOKEN_ADJUST_PRIVILEGES access.
 * @param lpName       Privilege name, e.g. "SeProfileSingleProcessPrivilege".
 * @param flags        Non-zero to enable the privilege, zero to disable it.
 * @return TRUE if the privilege name was resolved and AdjustTokenPrivileges succeeded,
 *         FALSE otherwise (GetLastError() holds the reason).
 *
 * Note: AdjustTokenPrivileges also reports success when the token does not hold the
 * privilege at all (GetLastError() == ERROR_NOT_ALL_ASSIGNED), so TRUE does not guarantee
 * that the privilege is now enabled.
 */
BOOL
ObtainPrivilege(HANDLE TokenHandle, LPCSTR lpName, int flags)
{
	LUID Luid;
	if (!LookupPrivilegeValueA(nullptr, lpName, &Luid))
	{
		return FALSE;
	}

	TOKEN_PRIVILEGES NewPriv;
	NewPriv.PrivilegeCount			 = 1;
	NewPriv.Privileges[0].Luid		 = Luid;
	NewPriv.Privileges[0].Attributes = flags != 0 ? SE_PRIVILEGE_ENABLED : 0;

	// A single call sets the final state directly; the previous state is not needed, so no
	// PreviousState buffer is supplied (BufferLength may then be zero).
	return AdjustTokenPrivileges(TokenHandle, FALSE, &NewPriv, 0, nullptr, nullptr);
}

typedef NTSTATUS(WINAPI* NtSetSystemInformationFn)(INT, PVOID, ULONG);
typedef NTSTATUS(WINAPI* NtQuerySystemInformationFn)(INT, PVOID, ULONG, PULONG);

void
EmptyStandByList()
{
	HMODULE NtDll = LoadLibrary(L"ntdll.dll");
	if (!NtDll)
	{
		zen::ThrowLastError("Could not LoadLibrary ntdll");
	}

	HANDLE hToken;

	if (!OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY | TOKEN_ADJUST_PRIVILEGES, &hToken))
	{
		zen::ThrowLastError("Could not open current process token");
	}

	if (!ObtainPrivilege(hToken, "SeProfileSingleProcessPrivilege", 1))
	{
		zen::ThrowLastError("Unable to obtain SeProfileSingleProcessPrivilege");
	}

	CloseHandle(hToken);

	NtSetSystemInformationFn   NtSetSystemInformation	= (NtSetSystemInformationFn)GetProcAddress(NtDll, "NtSetSystemInformation");
	NtQuerySystemInformationFn NtQuerySystemInformation = (NtQuerySystemInformationFn)GetProcAddress(NtDll, "NtQuerySystemInformation");

	if (!NtSetSystemInformation || !NtQuerySystemInformation)
	{
		throw std::runtime_error("Failed to look up required ntdll functions");
	}

	SYSTEM_MEMORY_LIST_COMMAND MemoryListCommand = MemoryPurgeStandbyList;
	NTSTATUS NtStatus = NtSetSystemInformation(SystemMemoryListInformation, &MemoryListCommand, sizeof(MemoryListCommand));

	if (NtStatus == STATUS_PRIVILEGE_NOT_HELD)
	{
		throw elevation_required_exception("Insufficient privileges to execute the memory list command");
	}
	else if (!NT_SUCCESS(NtStatus))
	{
		throw std::runtime_error(fmt::format("Unable to execute the memory list command (status={})", NtStatus));
	}
}

}  // namespace zen::bench::util

#else

namespace zen::bench::util {

// The standby-list purge is only implemented for Windows; on every other platform this is
// a deliberate no-op so callers need no platform checks of their own.
void
EmptyStandByList()
{
}

}  // namespace zen::bench::util

#endif
